{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Example of using our Proposed Spatio-Temporal Graph Neural Networks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "from stgraph_trainer.datasets import load_province_temporal_data\n",
    "from stgraph_trainer.datasets import load_province_coordinates\n",
    "from stgraph_trainer.datasets import preprocess_data_for_stgnn\n",
    "from stgraph_trainer.utils import PairDataset\n",
    "from stgraph_trainer.utils import compute_metrics\n",
    "from stgraph_trainer.utils import get_distance_in_km_between_earth_coordinates\n",
    "from stgraph_trainer.utils import get_adjacency_matrix\n",
    "from stgraph_trainer.utils import get_normalized_adj\n",
    "from torch.utils.data import DataLoader\n",
    "from stgraph_trainer.models import ProposedSTGNN\n",
    "from stgraph_trainer.trainers import ProposedSTGNNTrainer\n",
    "import torch\n",
    "import numpy as np\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load and process dataset\n",
    "### Setup parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Date at which preprocess_data_for_stgnn splits the series into train and test.\n",
    "SPLIT_DATE = '2020-10-20'\n",
    "# Number of past time steps in every input window (also the model's `time_steps`).\n",
    "TIME_STEPS = 7\n",
    "PROVINCES = [\n",
    "  \"Seoul\",\n",
    "  \"Busan\",\n",
    "  \"Daegu\",\n",
    "  \"Incheon\",\n",
    "  \"Gwangju\",\n",
    "  \"Daejeon\",\n",
    "  \"Ulsan\",\n",
    "  \"Sejong\",\n",
    "  \"Gyeonggi\",\n",
    "  \"Gangwon\",\n",
    "  \"Chungbuk\",\n",
    "  \"Chungnam\",\n",
    "  \"Jeonbuk\",\n",
    "  \"Jeonnam\",\n",
    "  \"Gyeongbuk\",\n",
    "  \"Gyeongnam\",\n",
    "  \"Jeju\"]\n",
    "# Forwarded as `status` to load_province_temporal_data.\n",
    "STATUS = 'New'\n",
    "BATCH_SIZE = 16\n",
    "EPOCHS = 10\n",
    "\n",
    "# Seed the global RNGs so that weight initialisation, dropout and DataLoader\n",
    "# shuffling repeat on Restart & Run All (GPU kernels may still differ slightly).\n",
    "SEED = 42\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Use the first CUDA device when one is available, otherwise fall back to the CPU.\n",
    "device = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Temporal data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the temporal series for the configured provinces and status.\n",
    "df = load_province_temporal_data(provinces=PROVINCES, status=STATUS)\n",
    "\n",
    "# Only the four split arrays and the scaler are kept; the two values bound to `_` are discarded.\n",
    "X_train, y_train, X_test, y_test, _, _, scaler = preprocess_data_for_stgnn(\n",
    "    df, SPLIT_DATE, TIME_STEPS)\n",
    "\n",
    "# Turn every split into a tensor and append a trailing axis of size 1.\n",
    "X_train, y_train, X_test, y_test = (\n",
    "    torch.tensor(split).unsqueeze(-1)\n",
    "    for split in (X_train, y_train, X_test, y_test))\n",
    "n_test_samples = y_test.shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([194, 17, 7, 1]),\n",
       " torch.Size([194, 17, 1]),\n",
       " torch.Size([85, 17, 7, 1]),\n",
       " torch.Size([85, 17, 1]))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shapes of the train/test inputs and targets, in that order.\n",
    "tuple(split.shape for split in (X_train, y_train, X_test, y_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the training tensors in a PairDataset and draw shuffled mini-batches of BATCH_SIZE from it.\n",
    "train_pairs = PairDataset(X_train, y_train)\n",
    "train_dl = DataLoader(train_pairs, shuffle=True, batch_size=BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Spatial data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([17, 17])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Drop the first column of the coordinates table; the remaining columns go to the\n",
    "# distance helper (presumably latitude/longitude - confirm against the loader).\n",
    "# NOTE(review): assumes the rows are in the same order as PROVINCES - confirm.\n",
    "province_coords = load_province_coordinates().values[:, 1:]\n",
    "\n",
    "# Distance in km between every ordered pair of provinces.\n",
    "dist_km = [\n",
    "    [get_distance_in_km_between_earth_coordinates(c1, c2) for c2 in province_coords]\n",
    "    for c1 in province_coords]\n",
    "dist_mx = np.array(dist_km)\n",
    "\n",
    "adj_mx = get_adjacency_matrix(dist_mx)\n",
    "# Cast to float32 (torch's default floating dtype) before normalising.\n",
    "adj_mx = adj_mx.astype(np.float32)\n",
    "\n",
    "adj_mx = get_normalized_adj(adj_mx)\n",
    "adj = torch.tensor(adj_mx)\n",
    "\n",
    "# The graph must have exactly one node per province used for the temporal data.\n",
    "assert adj.shape == (len(PROVINCES), len(PROVINCES)), 'adjacency matrix does not match PROVINCES'\n",
    "adj.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Network hyper-parameters; the node count and input channels are read from the data.\n",
    "model_kwargs = dict(\n",
    "    n_nodes=adj.shape[0],\n",
    "    time_steps=TIME_STEPS,\n",
    "    predicted_time_steps=1,\n",
    "    in_channels=X_train.shape[3],\n",
    "    spatial_channels=32,\n",
    "    spatial_hidden_channels=16,\n",
    "    spatial_out_channels=16,\n",
    "    out_channels=16,\n",
    "    temporal_kernel=3,\n",
    "    drop_rate=0.2,\n",
    ")\n",
    "# Build the model and move it to the selected device.\n",
    "model = ProposedSTGNN(**model_kwargs).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adam over all model parameters, trained against a mean-squared-error loss.\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
    "loss_func = torch.nn.MSELoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Values of the last n_test_samples + 1 rows of df, i.e. one extra row before the\n",
    "# test rows (presumably the baseline the trainer needs - confirm in ProposedSTGNNTrainer).\n",
    "raw_test = df.iloc[-(n_test_samples + 1):].values\n",
    "\n",
    "trainer = ProposedSTGNNTrainer(\n",
    "    model, train_dl, X_test, adj, scaler, loss_func, optimizer, device,\n",
    "    callbacks=None,\n",
    "    raw_test=raw_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1; Elapsed time: 0.04953336715698242; Train loss: 0.969324; Test MSE: 405.618561; Test loss RMSE: 20.139974\n",
      "Epoch: 2; Elapsed time: 0.05053210258483887; Train loss: 0.965432; Test MSE: 396.617920; Test loss RMSE: 19.915269\n",
      "Epoch: 3; Elapsed time: 0.046793222427368164; Train loss: 0.912030; Test MSE: 387.234161; Test loss RMSE: 19.678266\n",
      "Epoch: 4; Elapsed time: 0.04871797561645508; Train loss: 0.897380; Test MSE: 378.395599; Test loss RMSE: 19.452393\n",
      "Epoch: 5; Elapsed time: 0.046854257583618164; Train loss: 0.911294; Test MSE: 367.558472; Test loss RMSE: 19.171815\n",
      "Epoch: 6; Elapsed time: 0.04424715042114258; Train loss: 0.920928; Test MSE: 360.120514; Test loss RMSE: 18.976842\n",
      "Epoch: 7; Elapsed time: 0.04700469970703125; Train loss: 0.782902; Test MSE: 347.656372; Test loss RMSE: 18.645546\n",
      "Epoch: 8; Elapsed time: 0.03743457794189453; Train loss: 0.837908; Test MSE: 343.262146; Test loss RMSE: 18.527335\n",
      "Epoch: 9; Elapsed time: 0.03305935859680176; Train loss: 0.770195; Test MSE: 339.106934; Test loss RMSE: 18.414856\n",
      "Epoch: 10; Elapsed time: 0.0323178768157959; Train loss: 0.763083; Test MSE: 333.940979; Test loss RMSE: 18.274052\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'epoch': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],\n",
       " 'train_loss': [0.9693244076692141,\n",
       "  0.9654320902549304,\n",
       "  0.9120301294785279,\n",
       "  0.8973803657751817,\n",
       "  0.9112942195855654,\n",
       "  0.9209284346837264,\n",
       "  0.7829024496559913,\n",
       "  0.8379080983308645,\n",
       "  0.7701948285102844,\n",
       "  0.7630832974727337],\n",
       " 'test_loss': [405.6185607910156,\n",
       "  396.617919921875,\n",
       "  387.2341613769531,\n",
       "  378.3955993652344,\n",
       "  367.5584716796875,\n",
       "  360.1205139160156,\n",
       "  347.6563720703125,\n",
       "  343.26214599609375,\n",
       "  339.10693359375,\n",
       "  333.94097900390625],\n",
       " 'elapsed_time': [0.04953336715698242,\n",
       "  0.05053210258483887,\n",
       "  0.046793222427368164,\n",
       "  0.04871797561645508,\n",
       "  0.046854257583618164,\n",
       "  0.04424715042114258,\n",
       "  0.04700469970703125,\n",
       "  0.03743457794189453,\n",
       "  0.03305935859680176,\n",
       "  0.0323178768157959]}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train for EPOCHS epochs; the returned history holds per-epoch lists for\n",
    "# 'epoch', 'train_loss', 'test_loss' and 'elapsed_time'.\n",
    "history = trainer.train(EPOCHS)\n",
    "history"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predictions from the trained model (presumably for the X_test given to the trainer - confirm).\n",
    "predictions = trainer.predict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(85, 17)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One row per test sample, one column per province.\n",
    "predictions.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Seoul</th>\n",
       "      <th>Busan</th>\n",
       "      <th>Daegu</th>\n",
       "      <th>Incheon</th>\n",
       "      <th>Gwangju</th>\n",
       "      <th>Daejeon</th>\n",
       "      <th>Ulsan</th>\n",
       "      <th>Sejong</th>\n",
       "      <th>Gyeonggi</th>\n",
       "      <th>Gangwon</th>\n",
       "      <th>Chungbuk</th>\n",
       "      <th>Chungnam</th>\n",
       "      <th>Jeonbuk</th>\n",
       "      <th>Jeonnam</th>\n",
       "      <th>Gyeongbuk</th>\n",
       "      <th>Gyeongnam</th>\n",
       "      <th>Jeju</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>2020-10-20</th>\n",
       "      <td>16.434998</td>\n",
       "      <td>11.846681</td>\n",
       "      <td>0.377825</td>\n",
       "      <td>4.163242</td>\n",
       "      <td>1.670218</td>\n",
       "      <td>1.803768</td>\n",
       "      <td>0.168991</td>\n",
       "      <td>0.068562</td>\n",
       "      <td>20.012554</td>\n",
       "      <td>1.249839</td>\n",
       "      <td>0.670640</td>\n",
       "      <td>1.644362</td>\n",
       "      <td>0.739950</td>\n",
       "      <td>0.236018</td>\n",
       "      <td>1.074349</td>\n",
       "      <td>0.749913</td>\n",
       "      <td>0.121054</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2020-10-21</th>\n",
       "      <td>14.778962</td>\n",
       "      <td>4.063126</td>\n",
       "      <td>0.471197</td>\n",
       "      <td>3.794794</td>\n",
       "      <td>0.788469</td>\n",
       "      <td>0.833758</td>\n",
       "      <td>0.237530</td>\n",
       "      <td>0.063441</td>\n",
       "      <td>27.574669</td>\n",
       "      <td>2.103812</td>\n",
       "      <td>1.267345</td>\n",
       "      <td>1.764697</td>\n",
       "      <td>0.235352</td>\n",
       "      <td>0.246873</td>\n",
       "      <td>1.405837</td>\n",
       "      <td>1.110403</td>\n",
       "      <td>0.122224</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2020-10-22</th>\n",
       "      <td>16.539724</td>\n",
       "      <td>9.555019</td>\n",
       "      <td>0.224903</td>\n",
       "      <td>5.158150</td>\n",
       "      <td>1.186094</td>\n",
       "      <td>2.749295</td>\n",
       "      <td>0.524874</td>\n",
       "      <td>0.074417</td>\n",
       "      <td>25.867645</td>\n",
       "      <td>2.229380</td>\n",
       "      <td>0.963057</td>\n",
       "      <td>0.399974</td>\n",
       "      <td>0.372635</td>\n",
       "      <td>0.566936</td>\n",
       "      <td>1.667660</td>\n",
       "      <td>0.328303</td>\n",
       "      <td>0.119438</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2020-10-23</th>\n",
       "      <td>15.653477</td>\n",
       "      <td>4.867497</td>\n",
       "      <td>1.289272</td>\n",
       "      <td>3.717338</td>\n",
       "      <td>0.871321</td>\n",
       "      <td>0.594761</td>\n",
       "      <td>0.173816</td>\n",
       "      <td>0.040523</td>\n",
       "      <td>46.749451</td>\n",
       "      <td>1.958138</td>\n",
       "      <td>0.394497</td>\n",
       "      <td>6.471957</td>\n",
       "      <td>1.178643</td>\n",
       "      <td>0.281737</td>\n",
       "      <td>1.365990</td>\n",
       "      <td>0.443597</td>\n",
       "      <td>0.122340</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2020-10-24</th>\n",
       "      <td>17.662086</td>\n",
       "      <td>2.868346</td>\n",
       "      <td>3.139606</td>\n",
       "      <td>4.416123</td>\n",
       "      <td>3.090621</td>\n",
       "      <td>0.420742</td>\n",
       "      <td>0.239853</td>\n",
       "      <td>0.053558</td>\n",
       "      <td>77.659363</td>\n",
       "      <td>1.519273</td>\n",
       "      <td>0.440844</td>\n",
       "      <td>3.489955</td>\n",
       "      <td>1.287953</td>\n",
       "      <td>0.712951</td>\n",
       "      <td>0.656340</td>\n",
       "      <td>0.388869</td>\n",
       "      <td>0.119428</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                Seoul      Busan     Daegu   Incheon   Gwangju   Daejeon  \\\n",
       "2020-10-20  16.434998  11.846681  0.377825  4.163242  1.670218  1.803768   \n",
       "2020-10-21  14.778962   4.063126  0.471197  3.794794  0.788469  0.833758   \n",
       "2020-10-22  16.539724   9.555019  0.224903  5.158150  1.186094  2.749295   \n",
       "2020-10-23  15.653477   4.867497  1.289272  3.717338  0.871321  0.594761   \n",
       "2020-10-24  17.662086   2.868346  3.139606  4.416123  3.090621  0.420742   \n",
       "\n",
       "               Ulsan    Sejong   Gyeonggi   Gangwon  Chungbuk  Chungnam  \\\n",
       "2020-10-20  0.168991  0.068562  20.012554  1.249839  0.670640  1.644362   \n",
       "2020-10-21  0.237530  0.063441  27.574669  2.103812  1.267345  1.764697   \n",
       "2020-10-22  0.524874  0.074417  25.867645  2.229380  0.963057  0.399974   \n",
       "2020-10-23  0.173816  0.040523  46.749451  1.958138  0.394497  6.471957   \n",
       "2020-10-24  0.239853  0.053558  77.659363  1.519273  0.440844  3.489955   \n",
       "\n",
       "             Jeonbuk   Jeonnam  Gyeongbuk  Gyeongnam      Jeju  \n",
       "2020-10-20  0.739950  0.236018   1.074349   0.749913  0.121054  \n",
       "2020-10-21  0.235352  0.246873   1.405837   1.110403  0.122224  \n",
       "2020-10-22  0.372635  0.566936   1.667660   0.328303  0.119438  \n",
       "2020-10-23  1.178643  0.281737   1.365990   0.443597  0.122340  \n",
       "2020-10-24  1.287953  0.712951   0.656340   0.388869  0.119428  "
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Label the predictions with the province names and the index of the last n_test_samples rows of df.\n",
    "test_dates = df.index[-n_test_samples:]\n",
    "pd.DataFrame(predictions, index=test_dates, columns=PROVINCES).head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "18.27405273769909"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# RMSE of the predictions against the last n_test_samples rows of df.\n",
    "# compute_metrics returns two values; only the second one (m_avg) is displayed.\n",
    "y_true = df.iloc[-n_test_samples:]\n",
    "m, m_avg = compute_metrics(y_true, predictions, metric='rmse')\n",
    "m_avg"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Torch Env",
   "language": "python",
   "name": "torch_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
